// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/host/math/beam_search.h"
#include <cmath>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

namespace paddle {
namespace lite {
namespace host {
namespace math {
/*
 * A single beam-search candidate; comparable so candidates can be ranked.
 */
struct Item {
  Item() {}
  Item(size_t offset, size_t id, float score)
      : offset(offset), id(id), score(score) {}
  // offset in the higher lod level.
  size_t offset;
  // the candidate id
  size_t id;
  // the corresponding score
  float score;

  // Order primarily by score; equal scores fall back to the offset so the
  // ordering is total and deterministic.
  inline bool operator<(const Item &in) const {
    return (score < in.score) || ((score == in.score) && (offset < in.offset));
  }

  // NOTE: the former user-defined operator= (which returned void) was
  // removed; the implicitly-generated member-wise copy assignment is
  // identical in effect (Rule of Zero) and returns Item& as convention
  // requires.

  // Human-readable dump, e.g. "{offset: 0, id: 2, score: 0.5}".
  std::string ToString() const {
    std::ostringstream os;
    os << "{";
    os << "offset: " << offset << ", ";
    os << "id: " << id << ", ";
    os << "score: " << score << "";
    os << "}";
    return os.str();
  }
};

/*
 * Optionally prune the source sentences whose branches have all finished.
 * Pruning happens one step after finishing (hence pre_ids is consulted),
 * because the end tokens still have to be written out first.
 */
void PruneEndBeams(const Tensor *pre_ids,
                   const LoD &abs_lod,
                   std::vector<std::vector<Item>> *items,
                   size_t lod_level,
                   int end_id) {
  auto *pre_ids_data = pre_ids->data<int64_t>();
  auto &high_level = abs_lod[lod_level];
  // True iff every candidate in [begin, end) is an end token whose
  // predecessor token was the end token as well. Offsets that carry no
  // candidates are skipped, matching the original flag-based scan.
  auto beam_finished = [&](size_t begin, size_t end) {
    for (size_t pos = begin; pos < end; ++pos) {
      for (auto &candidate : items->at(pos)) {
        bool is_end = candidate.id == static_cast<size_t>(end_id) &&
                      pre_ids_data[pos] == end_id;
        if (!is_end) {
          return false;
        }
      }
    }
    return true;
  };
  for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
    size_t begin = high_level[src_idx];
    size_t end = high_level[src_idx + 1];
    if (!beam_finished(begin, end)) {
      continue;
    }
    // Every branch of this source sentence ended: drop the whole beam.
    for (size_t pos = begin; pos < end; ++pos) {
      items->at(pos).clear();
    }
  }
}

/*
 * Regroup the items by offset: result[k] holds every item whose offset
 * equals k. NOTE low performance.
 */
std::vector<std::vector<Item>> ToMap(
    const std::vector<std::vector<Item>> &items, size_t element_num) {
  std::vector<std::vector<Item>> grouped(element_num);
  for (const auto &beam : items) {
    for (const auto &candidate : beam) {
      grouped[candidate.offset].push_back(candidate);
    }
  }
  return grouped;
}

void Insert(std::vector<Item> *top_beam_ptr,
            const Item &item,
            size_t beam_size) {
  std::vector<Item> &top_beam = *top_beam_ptr;

  size_t num_beams = top_beam.size();
  if (num_beams < beam_size) {
    top_beam.resize(num_beams + 1);
    num_beams++;
  } else {
    if (item < top_beam[beam_size - 1]) {
      return;
    }
  }

  for (int k = static_cast<int>(num_beams) - 2; k >= 0; --k) {
    if (top_beam[k] < item) {
      top_beam[k + 1] = top_beam[k];
    } else {
      top_beam[k + 1] = item;
      return;
    }
  }
  top_beam[0] = item;
}

/*
 * For each source, select top beam_size records.
 */
std::vector<std::vector<Item>> SelectTopBeamSizeItems(const Tensor *pre_ids,
                                                      const Tensor *pre_scores,
                                                      const Tensor *ids,
                                                      const Tensor *scores,
                                                      size_t lod_level,
                                                      size_t beam_size,
                                                      int end_id,
                                                      bool is_accumulated) {
  std::vector<std::vector<Item>> result;

  // find the current candidates
  // NOTE(review): the lod of `scores` is used directly as absolute offsets
  // (the framework::ToAbsOffset conversion is commented out) — assumes the
  // stored lod is already absolute; confirm against the op's callers.
  // auto abs_lod = framework::ToAbsOffset(scores->lod());
  auto abs_lod = scores->lod();
  auto *pre_ids_data = pre_ids->data<int64_t>();
  auto *pre_scores_data = pre_scores->data<float>();

  // `ids` may be null; a candidate then takes its column index d as its id.
  auto *ids_data = ids ? ids->data<int64_t>() : nullptr;
  auto *scores_data = scores->data<float>();

  size_t num_seqs = abs_lod[lod_level].size() - 1;
  // Number of candidates per branch: product of all non-leading dims.
  size_t seq_width = 1;
  for (int i = 1; i < scores->dims().size(); i++) {
    seq_width *= scores->dims()[i];
  }

  for (size_t seq_id = 0; seq_id < num_seqs; ++seq_id) {
    // [seq_offset_start, seq_offset_end) are the branches of this source.
    size_t seq_offset_start = abs_lod[lod_level][seq_id];
    size_t seq_offset_end = abs_lod[lod_level][seq_id + 1];

    std::vector<Item> top_beam;
    top_beam.reserve(beam_size);

    for (size_t offset = seq_offset_start; offset < seq_offset_end; ++offset) {
      auto pre_id = pre_ids_data[offset];
      auto pre_score = pre_scores_data[offset];
      if (pre_id == end_id) {
        // Allocate all probability mass to end_id for finished branches and
        // the other candidate ids can be ignored.
        Item item(offset, end_id, pre_score);
        Insert(&top_beam, item, beam_size);
      } else {
        // Live branch: consider every candidate column. When scores are not
        // accumulated yet, fold in the prefix score in log space.
        size_t index = offset * seq_width;
        for (size_t d = 0; d < seq_width; d++, index++) {
          int64_t id = ids_data ? ids_data[index] : static_cast<int64_t>(d);
          float score = is_accumulated
                            ? scores_data[index]
                            : pre_score + std::log(scores_data[index]);
          Item item(offset, id, score);
          Insert(&top_beam, item, beam_size);
        }
      }
    }

    result.emplace_back(top_beam);
  }
  return result;
}

void beam_search(const Tensor *pre_ids,
                 const Tensor *pre_scores,
                 const Tensor *ids,
                 const Tensor *scores,
                 Tensor *selected_ids,
                 Tensor *selected_scores,
                 Tensor *parent_idx,
                 int level,
                 int beam_size,
                 int end_id,
                 bool is_accumulated) {
  auto abs_lod = scores->lod();
  auto &high_level = abs_lod[level];
  auto items = SelectTopBeamSizeItems(pre_ids,
                                      pre_scores,
                                      ids,
                                      scores,
                                      level,
                                      beam_size,
                                      end_id,
                                      is_accumulated);
  auto selected_items = ToMap(items, high_level.back());

  PruneEndBeams(pre_ids, abs_lod, &selected_items, level, end_id);
  // calculate the output tensor's height
  size_t num_instances = std::accumulate(
      std::begin(selected_items),
      std::end(selected_items),
      0,
      [](size_t a, std::vector<Item> &b) { return a + b.size(); });
  // the output tensor shape should be [num_instances, 1]
  auto dims = std::vector<int64_t>({static_cast<int>(num_instances), 1});
  selected_ids->Resize(dims);
  selected_scores->Resize(dims);
  if (parent_idx) {
    parent_idx->Resize({static_cast<int64_t>(num_instances)});
  }
  auto *selected_ids_data = selected_ids->mutable_data<int64_t>();
  auto *selected_scores_data = selected_scores->mutable_data<float>();
  auto *parent_idx_data =
      parent_idx ? parent_idx->mutable_data<int>() : nullptr;

  // fill in data
  std::vector<size_t> low_level;
  size_t low_offset = 0;
  for (auto &items : selected_items) {
    low_level.push_back(low_offset);
    for (auto &item : items) {
      if (parent_idx) {
        parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
      }
      selected_ids_data[low_offset] = item.id;
      selected_scores_data[low_offset] = item.score;
      low_offset++;
    }
  }
  low_level.push_back(low_offset);

  // fill lod
  LoD lod(2);
  lod[0].assign(high_level.begin(), high_level.end());
  lod[1].assign(low_level.begin(), low_level.end());
  *(selected_ids->mutable_lod()) = lod;
  *(selected_scores->mutable_lod()) = lod;
}

}  // namespace math
}  // namespace host
}  // namespace lite
}  // namespace paddle
